{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "texture_nca_grafting.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qbyJuqANoinE"
      },
      "source": [
        "# Neural CA Grafting\n",
        "\n",
        "This is a modiffied version of [Self-Organizing Textures NCA](https://distill.pub/selforg/2021/textures) that was created for the [Neural CA Grafting video tutorial](https://youtu.be/Tbe-41HowwY)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "INX-7rI-cy0I"
      },
      "source": [
        "*Copyright 2021 Google LLC*\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at\n",
        "\n",
        "[https://www.apache.org/licenses/LICENSE-2.0](https://www.apache.org/licenses/LICENSE-2.0)\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software\n",
        "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "See the License for the specific language governing permissions and\n",
        "limitations under the License."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "FR8YNR-g9JXA",
        "cellView": "form"
      },
      "source": [
        "#@title Imports and Notebook Utilities\n",
        "%tensorflow_version 2.x\n",
        "\n",
        "import os\n",
        "import io\n",
        "import PIL.Image, PIL.ImageDraw\n",
        "import base64\n",
        "import zipfile\n",
        "import json\n",
        "import requests\n",
        "import numpy as np\n",
        "import matplotlib.pylab as pl\n",
        "import glob\n",
        "\n",
        "from IPython.display import Image, HTML, clear_output\n",
        "from tqdm import tqdm_notebook, tnrange\n",
        "\n",
        "os.environ['FFMPEG_BINARY'] = 'ffmpeg'\n",
        "import moviepy.editor as mvp\n",
        "from moviepy.video.io.ffmpeg_writer import FFMPEG_VideoWriter\n",
        "\n",
        "\n",
        "def imread(url, max_size=None, mode=None):\n",
        "  if url.startswith(('http:', 'https:')):\n",
        "    # wikimedia requires a user agent\n",
        "    headers = {\n",
        "      \"User-Agent\": \"Requests in Colab/0.0 (https://colab.research.google.com/; no-reply@google.com) requests/0.0\"\n",
        "    }\n",
        "    r = requests.get(url, headers=headers)\n",
        "    f = io.BytesIO(r.content)\n",
        "  else:\n",
        "    f = url\n",
        "  img = PIL.Image.open(f)\n",
        "  if max_size is not None:\n",
        "    img.thumbnail((max_size, max_size), PIL.Image.ANTIALIAS)\n",
        "  if mode is not None:\n",
        "    img = img.convert(mode)\n",
        "  img = np.float32(img)/255.0\n",
        "  return img\n",
        "\n",
        "def np2pil(a):\n",
        "  if a.dtype in [np.float32, np.float64]:\n",
        "    a = np.uint8(np.clip(a, 0, 1)*255)\n",
        "  return PIL.Image.fromarray(a)\n",
        "\n",
        "def imwrite(f, a, fmt=None):\n",
        "  a = np.asarray(a)\n",
        "  if isinstance(f, str):\n",
        "    fmt = f.rsplit('.', 1)[-1].lower()\n",
        "    if fmt == 'jpg':\n",
        "      fmt = 'jpeg'\n",
        "    f = open(f, 'wb')\n",
        "  np2pil(a).save(f, fmt, quality=95)\n",
        "\n",
        "def imencode(a, fmt='jpeg'):\n",
        "  a = np.asarray(a)\n",
        "  if len(a.shape) == 3 and a.shape[-1] == 4:\n",
        "    fmt = 'png'\n",
        "  f = io.BytesIO()\n",
        "  imwrite(f, a, fmt)\n",
        "  return f.getvalue()\n",
        "\n",
        "def im2url(a, fmt='jpeg'):\n",
        "  encoded = imencode(a, fmt)\n",
        "  base64_byte_string = base64.b64encode(encoded).decode('ascii')\n",
        "  return 'data:image/' + fmt.upper() + ';base64,' + base64_byte_string\n",
        "\n",
        "def imshow(a, fmt='jpeg'):\n",
        "  display(Image(data=imencode(a, fmt)))\n",
        "\n",
        "def tile2d(a, w=None):\n",
        "  a = np.asarray(a)\n",
        "  if w is None:\n",
        "    w = int(np.ceil(np.sqrt(len(a))))\n",
        "  th, tw = a.shape[1:3]\n",
        "  pad = (w-len(a))%w\n",
        "  a = np.pad(a, [(0, pad)]+[(0, 0)]*(a.ndim-1), 'constant')\n",
        "  h = len(a)//w\n",
        "  a = a.reshape([h, w]+list(a.shape[1:]))\n",
        "  a = np.rollaxis(a, 2, 1).reshape([th*h, tw*w]+list(a.shape[4:]))\n",
        "  return a\n",
        "\n",
        "def zoom(img, scale=4):\n",
        "  img = np.repeat(img, scale, 0)\n",
        "  img = np.repeat(img, scale, 1)\n",
        "  return img\n",
        "\n",
        "class VideoWriter:\n",
        "  def __init__(self, filename='_autoplay.mp4', fps=30.0, **kw):\n",
        "    self.writer = None\n",
        "    self.params = dict(filename=filename, fps=fps, **kw)\n",
        "\n",
        "  def add(self, img):\n",
        "    img = np.asarray(img)\n",
        "    if self.writer is None:\n",
        "      h, w = img.shape[:2]\n",
        "      self.writer = FFMPEG_VideoWriter(size=(w, h), **self.params)\n",
        "    if img.dtype in [np.float32, np.float64]:\n",
        "      img = np.uint8(img.clip(0, 1)*255)\n",
        "    if len(img.shape) == 2:\n",
        "      img = np.repeat(img[..., None], 3, -1)\n",
        "    self.writer.write_frame(img)\n",
        "\n",
        "  def close(self):\n",
        "    if self.writer:\n",
        "      self.writer.close()\n",
        "\n",
        "  def __enter__(self):\n",
        "    return self\n",
        "\n",
        "  def __exit__(self, *kw):\n",
        "    self.close()\n",
        "    if self.params['filename'] == '_autoplay.mp4':\n",
        "      self.show(loop=True)\n",
        "\n",
        "  def show(self, **kw):\n",
        "      self.close()\n",
        "      fn = self.params['filename']\n",
        "      display(mvp.ipython_display(fn, **kw))\n",
        "\n",
        "!nvidia-smi -L"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "C1hkDT38hSaP"
      },
      "source": [
        "import torch\n",
        "import torchvision.models as models\n",
        "\n",
        "torch.set_default_tensor_type('torch.cuda.FloatTensor')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "vFqmv9w5tPQB",
        "cellView": "form"
      },
      "source": [
        "#@title VGG16-based Style Model\n",
        "vgg16 = models.vgg16(pretrained=True).features\n",
        "\n",
        "def calc_styles(imgs):\n",
        "  style_layers = [1, 6, 11, 18, 25]  \n",
        "  mean = torch.tensor([0.485, 0.456, 0.406])[:,None,None]\n",
        "  std = torch.tensor([0.229, 0.224, 0.225])[:,None,None]\n",
        "  x = (imgs-mean) / std\n",
        "  grams = []\n",
        "  for i, layer in enumerate(vgg16[:max(style_layers)+1]):\n",
        "    x = layer(x)\n",
        "    if i in style_layers:\n",
        "      h, w = x.shape[-2:]\n",
        "      y = x.clone()  # workaround for pytorch in-place modification bug(?)\n",
        "      gram = torch.einsum('bchw, bdhw -> bcd', y, y) / (h*w)\n",
        "      grams.append(gram)\n",
        "  return grams\n",
        "\n",
        "def style_loss(grams_x, grams_y):\n",
        "  loss = 0.0\n",
        "  for x, y in zip(grams_x, grams_y):\n",
        "    loss = loss + (x-y).square().mean()\n",
        "  return loss\n",
        "\n",
        "def to_nchw(img):\n",
        "  img = torch.as_tensor(img)\n",
        "  if len(img.shape) == 3:\n",
        "    img = img[None,...]\n",
        "  return img.permute(0, 3, 1, 2)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ewXnAPrbHx0-"
      },
      "source": [
        "#@title Minimalistic Neural CA\n",
        "ident = torch.tensor([[0.0,0.0,0.0],[0.0,1.0,0.0],[0.0,0.0,0.0]])\n",
        "sobel_x = torch.tensor([[-1.0,0.0,1.0],[-2.0,0.0,2.0],[-1.0,0.0,1.0]])\n",
        "lap = torch.tensor([[1.0,2.0,1.0],[2.0,-12,2.0],[1.0,2.0,1.0]])\n",
        "\n",
        "def perchannel_conv(x, filters):\n",
        "  '''filters: [filter_n, h, w]'''\n",
        "  b, ch, h, w = x.shape\n",
        "  y = x.reshape(b*ch, 1, h, w)\n",
        "  y = torch.nn.functional.pad(y, [1, 1, 1, 1], 'circular')\n",
        "  y = torch.nn.functional.conv2d(y, filters[:,None])\n",
        "  return y.reshape(b, -1, h, w)\n",
        "\n",
        "def perception(x):\n",
        "  filters = torch.stack([ident, sobel_x, sobel_x.T, lap])\n",
        "  return perchannel_conv(x, filters)\n",
        "\n",
        "class CA(torch.nn.Module):\n",
        "  def __init__(self, chn=12, hidden_n=96):\n",
        "    super().__init__()\n",
        "    self.chn = chn\n",
        "    self.w1 = torch.nn.Conv2d(chn*4, hidden_n, 1)\n",
        "    self.w2 = torch.nn.Conv2d(hidden_n, chn, 1, bias=False)\n",
        "    self.w2.weight.data.zero_()\n",
        "\n",
        "  def forward(self, x, update_rate=0.5):\n",
        "    y = perception(x)\n",
        "    y = self.w2(torch.relu(self.w1(y)))\n",
        "    b, c, h, w = y.shape\n",
        "    udpate_mask = (torch.rand(b, 1, h, w)+update_rate).floor()\n",
        "    return x+y*udpate_mask\n",
        "\n",
        "  def seed(self, n, sz=128):\n",
        "    return torch.zeros(n, self.chn, sz, sz)\n",
        "\n",
        "def to_rgb(x):\n",
        "  return x[...,:3,:,:]+0.5\n",
        "\n",
        "param_n = sum(p.numel() for p in CA().parameters())\n",
        "print('CA param count:', param_n)\n",
        "if not os.path.exists('init.pt'):\n",
        "  torch.save(CA(), 'init.pt')\n",
        "  print('saved init.pt')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "V8Af9R9mbtsn"
      },
      "source": [
        "#@title targets {vertical-output: true}\n",
        "style_urls = {\n",
        "  'dots': 'https://www.robots.ox.ac.uk/~vgg/data/dtd/thumbs/dotted/dotted_0090.jpg',\n",
        "  'chess': 'https://www.robots.ox.ac.uk/~vgg/data/dtd/thumbs/chequered/chequered_0121.jpg',\n",
        "  'bubbles': 'https://www.robots.ox.ac.uk/~vgg/data/dtd/thumbs/bubbly/bubbly_0101.jpg',\n",
        "}\n",
        "\n",
        "imgs = [imread(url, max_size=128) for url in style_urls.values()]\n",
        "imshow(np.hstack(imgs))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "5NNHH6fSsLvv"
      },
      "source": [
        "#@title Target image {vertical-output: true}\n",
        "target_name = 'bubbles'\n",
        "style_img = imread(style_urls[target_name], max_size=128)\n",
        "with torch.no_grad():\n",
        "  target_style = calc_styles(to_nchw(style_img))\n",
        "imshow(style_img)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "CjEkkadoH35s"
      },
      "source": [
        "#@title Loading pretrained models\n",
        "!wget -nc https://github.com/google-research/self-organising-systems/raw/master/assets/grafting_nca.zip && unzip grafting_nca.zip"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9Cd1uMD3cZw3"
      },
      "source": [
        "#@title setup training\n",
        "parent = 'dots' # replace this with 'init' to train from scratch\n",
        "model_name = parent+'_'+target_name\n",
        "ca = torch.load(parent+'.pt')\n",
        "opt = torch.optim.Adam(ca.parameters(), 1e-3)\n",
        "lr_sched = torch.optim.lr_scheduler.MultiStepLR(opt, [2000], 0.3)\n",
        "loss_log = []\n",
        "with torch.no_grad():\n",
        "  pool = ca.seed(256)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "3pxjdESnYlIC"
      },
      "source": [
        "#@title training loop {vertical-output: true}\n",
        "for i in range(4000):\n",
        "  with torch.no_grad():\n",
        "    batch_idx = np.random.choice(len(pool), 4, replace=False)\n",
        "    x = pool[batch_idx]\n",
        "    if i%8 == 0:\n",
        "      x[:1] = ca.seed(1)\n",
        "  step_n = np.random.randint(32, 96)\n",
        "  for k in range(step_n):\n",
        "    x = ca(x)\n",
        "  imgs = to_rgb(x)\n",
        "  styles = calc_styles(imgs)\n",
        "  overflow_loss = (x-x.clamp(-1.0, 1.0)).abs().sum()\n",
        "  loss = style_loss(styles, target_style)+overflow_loss\n",
        "  with torch.no_grad():\n",
        "    loss.backward()\n",
        "    for p in ca.parameters():\n",
        "      p.grad /= (p.grad.norm()+1e-8)   # normalize gradients \n",
        "    opt.step()\n",
        "    opt.zero_grad()\n",
        "    lr_sched.step()\n",
        "    pool[batch_idx] = x                # update pool\n",
        "    \n",
        "    loss_log.append(loss.item())\n",
        "    if i%32==0:\n",
        "      clear_output(True)\n",
        "      pl.plot(loss_log, '.', alpha=0.1)\n",
        "      pl.yscale('log')\n",
        "      pl.ylim(np.min(loss_log), loss_log[0])\n",
        "      pl.show()\n",
        "      imgs = to_rgb(x).permute([0, 2, 3, 1]).cpu()\n",
        "      imshow(np.hstack(imgs))\n",
        "      torch.save(ca, model_name+'.pt')\n",
        "    if i%10 == 0:\n",
        "      print('\\rstep_n:', len(loss_log),\n",
        "        ' loss:', loss.item(), \n",
        "        ' lr:', lr_sched.get_lr()[0], end='')\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ZTVkyfKc_2mF"
      },
      "source": [
        "#@title NCA video {vertical-output: true}\n",
        "def show_ca(ca):\n",
        "  with VideoWriter() as vid, torch.no_grad():\n",
        "    x = ca.seed(1, 256)\n",
        "    for k in tnrange(300, leave=False):\n",
        "      step_n = min(2**(k//30), 16)\n",
        "      for i in range(step_n):\n",
        "        x[:] = ca(x)\n",
        "      img = to_rgb(x[0]).permute(1, 2, 0).cpu()\n",
        "      vid.add(zoom(img, 2))\n",
        "\n",
        "show_ca(torch.load('dots.pt'))\n",
        "show_ca(torch.load('chess.pt'))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "epE536uFF9eL"
      },
      "source": [
        "W = 256\n",
        "with torch.no_grad():\n",
        "  r = torch.linspace(-1, 1, W)**2\n",
        "  r = (r+r[:,None]).sqrt()\n",
        "  mask = ((0.6-r)*8.0).sigmoid()\n",
        "  pl.contourf(mask.cpu())\n",
        "  pl.colorbar()\n",
        "  pl.axis('equal')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "lvxC_jidyXO1"
      },
      "source": [
        "ca1 = torch.load('dots_chess.pt')\n",
        "ca2 = torch.load('dots_bubbles.pt')\n",
        "with VideoWriter() as vid, torch.no_grad():\n",
        "  x = torch.zeros([1, ca1.chn, W, W])\n",
        "  for i in tnrange(600):\n",
        "    for k in range(8):\n",
        "      x1, x2 = ca1(x), ca2(x)\n",
        "      x = x1 + (x2-x1)*mask\n",
        "    img = to_rgb(x[0]).permute(1, 2, 0).cpu()\n",
        "    vid.add(zoom(img, 2))\n",
        "\n",
        "    "
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "G4O-AF_xn07Z"
      },
      "source": [
        "def show_pair(fn1, fn2, W=512):\n",
        "  ca1 = torch.load(fn1)\n",
        "  ca2 = torch.load(fn2)\n",
        "  with VideoWriter() as vid, torch.no_grad():\n",
        "    x = torch.zeros([1, ca1.chn, 128, W])\n",
        "    mask = 0.5-0.5*torch.linspace(0, 2.0*np.pi, W).cos()\n",
        "    for i in tnrange(300, leave=False):\n",
        "      for k in range(8):\n",
        "        x1, x2 = ca1(x), ca2(x)\n",
        "        x = x1 + (x2-x1)*mask\n",
        "      img = to_rgb(x[0]).permute(1, 2, 0).cpu()\n",
        "      vid.add(zoom(img, 2))\n",
        "\n",
        "show_pair('chess.pt', 'dots_bubbles.pt')\n",
        "show_pair('dots_chess.pt', 'dots_bubbles.pt')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "JVSsnq-6pEsv"
      },
      "source": [
        "show_pair('dots.pt', 'chess.pt')\n",
        "show_pair('dots.pt', 'dots_chess.pt')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "D-RKDUEXyROe"
      },
      "source": [
        "show_pair('dots.pt', 'dots_bubbles.pt')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4TrvXx_SAk95"
      },
      "source": [
        ""
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}